{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import numpy as np\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.style.use('ggplot')\n",
    "mpl.rcParams.update({'font.size': 12})\n",
    "mpl.rcParams.update({'xtick.labelsize': 'x-large'})\n",
    "mpl.rcParams.update({'xtick.major.size': '0'})\n",
    "mpl.rcParams.update({'ytick.labelsize': 'x-large'})\n",
    "mpl.rcParams.update({'ytick.major.size': '0'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://stackoverflow.com/questions/18386210/annotating-ranges-of-data-in-matplotlib\n",
    "def draw_brace(ax, xspan, yoffset, text):\n",
    "    \"\"\"Draws an annotated brace on the axes.\"\"\"\n",
    "    xmin, xmax = xspan\n",
    "    xspan = xmax - xmin\n",
    "    ax_xmin, ax_xmax = ax.get_xlim()\n",
    "    xax_span = ax_xmax - ax_xmin\n",
    "    ymin, ymax = ax.get_ylim()\n",
    "    yspan = ymax - ymin\n",
    "    resolution = int(xspan/xax_span*100)*2+1 # guaranteed uneven\n",
    "    beta = 300./xax_span # the higher this is, the smaller the radius\n",
    "\n",
    "    x = np.linspace(xmin, xmax, resolution)\n",
    "    x_half = x[:int(resolution/2+1)]\n",
    "    y_half_brace = (1/(1.+np.exp(-beta*(x_half-x_half[0])))\n",
    "                    + 1/(1.+np.exp(-beta*(x_half-x_half[-1]))))\n",
    "    y = np.concatenate((y_half_brace, y_half_brace[-2::-1]))\n",
    "    y = yoffset + (.05*y - .02)*yspan # adjust vertical position\n",
    "\n",
    "    ax.autoscale(False)\n",
    "    ax.plot(x, y, color='black', lw=1)\n",
    "\n",
    "    ax.text((xmax+xmin)/2., yoffset+.07*yspan, text, ha='center', va='bottom', fontdict={'family': 'monospace'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wall_potential(x, y, sphere_radius=5.0, inside=True):\n",
    "    r = np.sqrt(x ** 2 + y ** 2)\n",
    "    if inside:\n",
    "        r = sphere_radius - r\n",
    "    else:\n",
    "        r -= sphere_radius\n",
    "\n",
    "    def gauss_potential(r, epsilon=1.0, sigma=1.0):\n",
    "        arg = -0.5 * ((r / sigma) ** 2)\n",
    "        return epsilon * np.exp(arg)\n",
    "    \n",
    "    return np.where(r >= 0, gauss_potential(r), 0)\n",
    "\n",
    "spacing = np.linspace(-5, 5, 1000)\n",
    "x, y = np.meshgrid(spacing, spacing)\n",
    "\n",
    "fig = plt.Figure(figsize=(7, 4.32624056), dpi=100)\n",
    "ax = fig.add_subplot()\n",
    "img = ax.imshow(wall_potential(x, y).reshape((-1, 1000)), extent=[-5, 5, -5, 5])\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('y')\n",
    "divider = make_axes_locatable(ax)\n",
    "cb = fig.colorbar(img, ax=ax, orientation=\"vertical\", fraction=0.05, pad=-0.3)\n",
    "cb.set_label(\"Potential Energy\")\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig('../md-wall-potential.svg', bbox_inches='tight', facecolor=(1, 1, 1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spacing = np.linspace(-7.5, 7.5, 2000)\n",
    "x, y = np.meshgrid(spacing, spacing)\n",
    "\n",
    "fig = plt.Figure(figsize=(7, 4.32624056), dpi=100)\n",
    "ax = fig.add_subplot()\n",
    "img = ax.imshow(wall_potential(x, y, inside=False).reshape((-1, 2000)), extent=[-7.5, 7.5, -7.5, 7.5])\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('y')\n",
    "divider = make_axes_locatable(ax)\n",
    "cb = fig.colorbar(img, ax=ax, orientation=\"vertical\", fraction=0.05, pad=-0.3)\n",
    "cb.set_label(\"Potential Energy\")\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig('../md-wall-potential-outside.svg', bbox_inches='tight', facecolor=(1, 1, 1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lj_wall_energy(r, r_extrap, r_cut, sigma=1.0, epsilon=1.0):\n",
    "    def lj(r):\n",
    "        a = (sigma / r) ** 6\n",
    "        return 4 * epsilon * (a ** 2 - a)\n",
    "    \n",
    "    def lj_force(r):\n",
    "        lj1 = 4 * epsilon * (sigma ** 12)\n",
    "        lj2 = 4 * epsilon * (sigma ** 6)\n",
    "        r2_inv = 1 / (r * r)\n",
    "        r6_inv = r2_inv * r2_inv * r2_inv\n",
    "        return r2_inv * r6_inv * (12.0 * lj1 * r6_inv - 6.0 * lj2)\n",
    "    \n",
    "    if r_extrap == 0:\n",
    "        return lj(r)\n",
    "    \n",
    "    v_extrap = lj(r_extrap)\n",
    "    f_extrap = lj_force(r_extrap)\n",
    "    \n",
    "    return np.where(r <= r_extrap, v_extrap + (r_extrap - r) * f_extrap, lj(r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r = np.linspace(-1, 3, 1000)\n",
    "r_without = np.linspace(0.89, 3, 1000)\n",
    "\n",
    "energy = lj_wall_energy(r, 1.1, 3.0)\n",
    "energy_base = lj_wall_energy(r_without, 0.0, 3.0)\n",
    "\n",
    "fig = plt.Figure(figsize=(7, 4.32624056), dpi=100)\n",
    "ax = fig.add_subplot()\n",
    "\n",
    "ax.plot(r, energy, label=r\"Extrapolated $r_{extrap} = 1.1$\")\n",
    "ax.plot(r_without, energy_base, label=\"Standard\")\n",
    "ax.set_xlabel(\"$r$\")\n",
    "ax.set_ylabel(\"$V$\")\n",
    "ax.legend()\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig('../md-wall-extrapolate.svg', bbox_inches='tight', facecolor=(1, 1, 1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
